import os
import time
import math
from xml.sax.xmlreader import InputSource
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
from models.linear import Classifier
from utils import *
from losses import *
from dataset import get_vic_dataloader
from models import build_model

import pprint
import logging
import argparse
from methods import *

import torch
import torch.nn as nn
import torch.nn.functional as F


def _str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` does not work with argparse: every non-empty string,
    including ``'False'``, is truthy. This converter interprets the usual
    spellings instead and rejects anything it does not recognise.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behaviour of this function.

    Returns:
        argparse.Namespace holding the dataset, training, optimizer and
        logging options declared below.
    """
    parser = argparse.ArgumentParser(description='PyTorch Visual Classification Training')

    # dataset setting
    parser.add_argument('--root', default='../database/', help='the root dir of dataset')
    parser.add_argument('--dataset', default='CIFAR10', help='dataset setting')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size')

    # training setting
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')
    # NOTE(review): default=False combined with store_false means this option
    # can never become True; it may have been meant as store_true. Left as is
    # because flipping it would change the arguments other modules receive.
    parser.add_argument('--need_linear', default=False, action='store_false')
    parser.add_argument('--last_relu', default=False, action='store_true')
    parser.add_argument('--mode', default='linear', type=str, help='mode of the last layer')
    parser.add_argument('--loss', default='ce', type=str)
    parser.add_argument('--temp', default=1.0, type=float)
    parser.add_argument('--alpha', default=1.0, type=float)
    parser.add_argument('--adjusted', default=False, action='store_true')
    parser.add_argument('--eps', default=0.8, type=float)
    parser.add_argument('--reg', default=0.0, type=float, help='the coefficient of feature norm regularization')

    parser.add_argument('--mixup', default=False, action='store_true')
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--weight', default=None)
    parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ')
    parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    # Parsed with _str2bool so that "--deterministic False" really yields False.
    parser.add_argument('--deterministic', default=True, type=_str2bool)

    # optimizer setting
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)', dest='weight_decay')
    parser.add_argument('--scheduler', default='cos', type=str, help='The scheduler')

    # logging / checkpoint setting
    parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--save-path', default='./RESULTS', type=str)
    parser.add_argument('--root_log', type=str, default='vic_log')
    parser.add_argument('--root_model', type=str, default='checkpoint')

    return parser.parse_args(argv)


def train(model, classifier, loader, criterion, optimizer, epoch, writer, args, logger):
    """Run one training epoch and record the averaged statistics.

    Args:
        model: feature extractor; its output is passed to ``classifier``.
        classifier: head called as
            ``classifier(features, target, adjusted=..., eps=...)``; it must
            also provide ``margin()`` returning ``(margin, norm, ratio)``.
        loader: iterable yielding ``(input, target)`` batches.
        criterion: loss callable ``criterion(out, target)``. On the
            'normface' path with ``args.adjusted`` it is additionally given
            ``weight=``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index; used in the log line and as the
            TensorBoard step.
        writer: object exposing ``add_scalar(tag, value, step)``.
        args: namespace; reads ``gpu``, ``mixup``, ``gamma``, ``loss``,
            ``adjusted``, ``eps``, ``reg``, ``print_freq`` and ``epochs``.
        logger: logger used for the periodic progress line.
    """
    model.train()
    losses = AverageMeter('train/loss', ':.4e')
    top1 = AverageMeter('train/top1', ':.4e')
    # Was labelled 'train/top1', which made the two accuracy meters
    # indistinguishable by name.
    top5 = AverageMeter('train/top5', ':.4e')
    margins = AverageMeter('model/margin', ':.4e')
    norms = AverageMeter('model/norm', ':.4e')
    feature_norms = AverageMeter('train/feature_norm', ':.4e')
    ratios = AverageMeter('model/ratio', ':.4e')
    batch_time = AverageMeter('Time', ':6.3f')
    is_normface = args.loss.lower() == 'normface'
    end = time.time()
    for i, (inp, target) in enumerate(loader):
        inp, target = inp.cuda(args.gpu), target.cuda(args.gpu)
        if args.mixup:
            inp, target_a, target_b, lam = mixup_data(inp, target, alpha=args.gamma)
        features = model(inp)
        # The 'normface' loss never asks the classifier for adjusted logits;
        # the adjustment is applied through the loss weight further down.
        if is_normface:
            out = classifier(features, target, adjusted=False, eps=args.eps)
        else:
            out = classifier(features, target, adjusted=args.adjusted, eps=args.eps)

        # Sum of squared feature entries over the whole batch (not averaged).
        feature_norm = torch.sum(features**2)
        if not is_normface:
            loss = criterion(out, target) if not args.mixup else mixup_criterion(criterion, out, target_a, target_b, lam)
        else:
            # NOTE(review): when --mixup is combined with 'normface' the inputs
            # are mixed above but the loss still uses the unmixed targets.
            if args.adjusted:
                feat_norm = torch.norm(features, dim=-1)
                # Detached so the per-sample weight does not receive gradients.
                loss = criterion(out, target, weight=feat_norm.detach())
            else:
                loss = criterion(out, target)

        if args.reg > 0:
            loss = loss + args.reg * feature_norm

        # Accuracy is always measured against the original targets, also
        # when mixup is enabled.
        acc1, acc5 = accuracy(out, target, topk=(1, 5))
        margin, norm, ratio = classifier.margin()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        # Restart the timer; without this the meter accumulated the time since
        # the start of the epoch instead of the duration of each batch.
        end = time.time()
        losses.update(loss.item(), inp.size(0))
        top1.update(acc1.item(), inp.size(0))
        top5.update(acc5.item(), inp.size(0))
        margins.update(margin, inp.size(0))
        norms.update(norm, inp.size(0))
        feature_norms.update(feature_norm.item(), inp.size(0))
        ratios.update(ratio, inp.size(0))

        if i % args.print_freq == 0:
            output = ('Epoch: [{0}/{1}][{2}/{3}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                      'Margin {margin.val:.3f} ({margin.avg:.3f})\t'
                      'FeatNorm {featnorm.val:.3f} ({featnorm.avg:.3f})\t'
                      'Norm {norm.val:.3f} ({norm.avg:.3f})\t'
                      'Ratio {ratio.val:.3f} ({ratio.avg:.3f})\t'.format(
                epoch, args.epochs, i, len(loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5, margin=margins, norm=norms, featnorm=feature_norms, ratio=ratios, lr=optimizer.param_groups[0]['lr']))
            logger.info(output)

    # Per-epoch averages for TensorBoard.
    writer.add_scalar('train/loss', losses.avg, epoch)
    writer.add_scalar('train/top1', top1.avg, epoch)
    writer.add_scalar('train/top5', top5.avg, epoch)
    writer.add_scalar('train/feat_norm', feature_norms.avg, epoch)
    writer.add_scalar('model/margin', margins.avg, epoch)
    writer.add_scalar('model/norm', norms.avg, epoch)
    writer.add_scalar('model/ratio', ratios.avg, epoch)
    writer.add_scalar('train/lr', optimizer.param_groups[-1]['lr'], epoch)



def validate(loader, model, classifier, args):
    """Evaluate ``classifier(model(x))`` on ``loader``.

    Args:
        loader: iterable yielding ``(input, target)`` batches.
        model: feature extractor; switched to eval mode here.
        classifier: head applied to the features. Its train/eval mode is
            not touched by this function.
        args: namespace; only ``args.gpu`` is read.

    Returns:
        Tuple ``(top1_avg, top5_avg)`` of the sample-weighted average
        accuracies reported by ``accuracy``.
    """
    model.eval()
    val_top1 = AverageMeter('val/top1', ':.4e')
    # Was labelled 'val/top1', duplicating the name of the top-1 meter.
    val_top5 = AverageMeter('val/top5', ':.4e')
    with torch.no_grad():
        for inp, target in loader:
            # Move both tensors unconditionally, as train() does. The previous
            # ``if args.gpu is not None`` guard covered only the input, which
            # left it on the CPU while the target went to CUDA when gpu is None.
            inp, target = inp.cuda(args.gpu), target.cuda(args.gpu)
            output = classifier(model(inp))
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            val_top1.update(acc1.item(), inp.size(0))
            val_top5.update(acc5.item(), inp.size(0))
    return val_top1.avg, val_top5.avg



def main_worker(args):
    """Seed everything, build the training objects and run the epoch loop.

    Args:
        args: namespace from ``parse_args`` extended by ``main``. Besides the
            command-line options, ``in_features`` and ``num_classes`` are
            read here, and ``args.store_name`` is set to the output directory.

    Raises:
        ValueError: if ``args.scheduler`` or ``args.loss`` names an option
            this function does not know how to build.
    """
    # Fail fast with a clear message. Previously an unknown name left
    # ``scheduler`` / ``criterion`` unbound and surfaced later as an
    # UnboundLocalError.
    scheduler_name = args.scheduler.lower()
    loss_name = args.loss.lower()
    if scheduler_name not in ('cos', 'step'):
        raise ValueError("unsupported scheduler %r (expected 'cos' or 'step')" % (args.scheduler,))
    # NOTE(review): train() contains a 'normface' branch, but no criterion is
    # constructed for it here; add one if that loss is meant to be usable.
    if loss_name not in ('ce', 'asm'):
        raise ValueError("unsupported loss %r (expected 'ce' or 'asm')" % (args.loss,))

    seed = args.seed
    # cuDNN determinism is forced regardless of ``args.deterministic``.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # ``random`` and ``np`` are not imported explicitly in this file;
    # presumably they arrive through one of the star imports — verify.
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    dataloaders = get_vic_dataloader(args)
    train_loader = dataloaders['train']
    val_loader = dataloaders['val']

    model = build_model(args).cuda(args.gpu)
    classifier = Classifier(in_features=args.in_features, num_classes=args.num_classes, mode=args.mode, weight=args.weight, scale=args.temp).cuda(args.gpu)

    # Two parameter groups: backbone and classifier. The classifier group
    # repeats the global hyper-parameters explicitly.
    optimizer = torch.optim.SGD(
        [
            {'params': model.parameters()},
            {'params': classifier.parameters(), 'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay},
        ],
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    if scheduler_name == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0.0)
    else:  # 'step'
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    if loss_name == 'ce':
        criterion = nn.CrossEntropyLoss()
    else:  # 'asm'
        criterion = AveragedSampleMarginLoss(alpha=args.alpha)

    # The run directory name encodes the main hyper-parameters.
    store_name = args.dataset + '/' + '_'.join([str(i) for i in [args.loss, args.temp, args.mode, args.scheduler, args.lr, args.weight_decay, args.adjusted, args.eps, args.last_relu, args.reg]])
    args.store_name = os.path.join(args.save_path, args.root_log, store_name)
    # Create the directory explicitly so that opening log.txt below does not
    # depend on the summary writer having created it.
    os.makedirs(args.store_name, exist_ok=True)
    tf_writer = SummaryWriter(log_dir=args.store_name)

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=os.path.join(args.store_name, 'log.txt'), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Mirror the log messages to the console as well.
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    logger.info('\n' + pprint.pformat(args))

    best_acc = 0.0
    # NOTE(review): ``args.resume`` is not read here, so a non-zero
    # ``start_epoch`` only shortens the loop; no state is restored.
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train(model, classifier, train_loader, criterion, optimizer, epoch, tf_writer, args, logger)
            scheduler.step()
            acc1, acc5 = validate(val_loader, model, classifier, args)
            output = '\nEpoch [{}]: Validation\t Prec@1={:.4f}\t Prec@5={:.4f}\n'.format(
                epoch, acc1, acc5)
            logger.info(output)
            is_best = acc1 > best_acc
            best_acc = max(acc1, best_acc)
            tf_writer.add_scalar('val/best_acc', best_acc, epoch)
            tf_writer.add_scalar('val/top1', acc1, epoch)
            tf_writer.add_scalar('val/top5', acc5, epoch)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'classifier': classifier.state_dict()
                },
                is_best,
                args.store_name
            )
    finally:
        # Flush and release the event file even if training is interrupted.
        tf_writer.close()

def main():
    """Parse the command line, derive dataset-dependent settings and train.

    Raises:
        ValueError: if ``--dataset`` is not one of the datasets configured
            below (previously this surfaced as an AttributeError on
            ``args.num_classes``).
    """
    args = parse_args()

    # dataset name (lower-cased) -> (num_classes, in_features)
    dataset_specs = {
        'mnist': (10, 128),
        'cifar10': (10, 512),
        'cifar100': (100, 512),
    }
    dataset_key = args.dataset.lower()
    if dataset_key not in dataset_specs:
        raise ValueError('unsupported dataset %r (expected one of: %s)'
                         % (args.dataset, ', '.join(sorted(dataset_specs))))
    args.num_classes, args.in_features = dataset_specs[dataset_key]

    # NOTE(review): this unconditionally replaces any --alpha given on the
    # command line with 1 / (num_classes - 1); confirm that is intended.
    args.alpha = 1. / (args.num_classes - 1)
    args.c_weight_decay = 1.0 / ((args.num_classes - 1) * math.sqrt(args.batch_size/args.num_classes))

    # Modes starting with 'fix' or 'init' load a prototype matrix from disk;
    # the file is chosen by --last_relu and by the (classes x features) shape.
    if args.mode.startswith('fix') or args.mode.startswith('init'):
        path = './prototypes/weight1' if args.last_relu else './prototypes/weight0'
        args.weight = torch.Tensor(np.load(path + '_%dx%d.npy' % (args.num_classes, args.in_features)))
    main_worker(args)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
